import sys
from turtle import forward
from typing import Dict
sys.path.insert(0, '../Utilities/')

import oneflow as flow
from collections import OrderedDict

import numpy as np
import matplotlib.pyplot as plt
#import scipy.io
#from scipy.interpolate import griddata
#from plotting import newfig, savefig
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits import mplot3d
import matplotlib.gridspec as gridspec
#import time
import sys
sys.path.append("..")
from data.geometry_2d import Polygon

# Fixed seed so point sampling (and therefore training) is reproducible.
np.random.seed(1234)


# L-shaped computational domain described by its polygon vertices.
geom = Polygon([[0, 0], [1, 0], [1, -1], [-1, -1], [-1, 1], [0, 1]])

# training data: collocation points inside the domain plus boundary points
x_train_domain = geom.random_points(1200)
x_train_bc = geom.random_boundary_points(120)

# test data: interior points used for residual evaluation and plotting
x_test = geom.random_points(1500)
#print(x_train_domain.shape, x_train_bc.shape, x_test.shape)

# CUDA support: prefer the GPU when one is available
device = flow.device('cuda' if flow.cuda.is_available() else 'cpu')

# the deep neural network
class DNN(flow.nn.Module):
    """Fully connected feed-forward network with Tanh hidden activations.

    ``layers`` lists the width of every layer: e.g. [2, 50, 50, 1] builds
    two Tanh-activated hidden layers of width 50 between a 2-d input and a
    1-d output.  The final linear layer has no activation.
    """

    def __init__(self, layers):
        super(DNN, self).__init__()

        # number of weight (Linear) layers
        self.depth = len(layers) - 1

        # activation class applied after every hidden linear layer
        self.activation = flow.nn.Tanh

        # Build an ordered (name, module) sequence: Linear + Tanh for each
        # hidden layer, then a final Linear without activation.
        named_modules = []
        for idx in range(self.depth - 1):
            named_modules.append(
                ('layer_%d' % idx, flow.nn.Linear(layers[idx], layers[idx + 1]))
            )
            named_modules.append(('activation_%d' % idx, self.activation()))
        named_modules.append(
            ('layer_%d' % (self.depth - 1), flow.nn.Linear(layers[-2], layers[-1]))
        )

        # deploy layers
        self.layers = flow.nn.Sequential(OrderedDict(named_modules))

    def forward(self, x):
        """Apply the network to a (batch, layers[0]) input tensor."""
        return self.layers(x)


class Net_U_F(flow.nn.Module):
    """autograd version of calculating residual.

    Given coordinate tensors x and y (each with requires_grad=True),
    ``forward`` returns the network prediction u(x, y) together with the
    residual f = 1 + u_xx + u_yy, which training drives toward zero.
    """

    def __init__(self, layers):
        super(Net_U_F, self).__init__()

        # create net: the underlying approximator u(x, y)
        self.dnn = DNN(layers).to(device)

    @staticmethod
    def _partial(outputs, wrt):
        """First derivative of ``outputs`` w.r.t. ``wrt``.

        The graph is retained and extended (create_graph=True) so that
        second-order derivatives can be taken afterwards.
        """
        return flow.autograd.grad(
            outputs, wrt,
            grad_outputs=flow.ones_like(outputs),
            retain_graph=True,
            create_graph=True,
        )[0]

    def forward(self, x, y):
        u_pred = self.dnn(flow.cat([x, y], dim=1))
        u_x = self._partial(u_pred, x)
        u_y = self._partial(u_pred, y)
        u_xx = self._partial(u_x, x)
        u_yy = self._partial(u_y, y)
        # condition -> 0: PDE residual minimized by the training loss
        f_pred = 1 + u_xx + u_yy
        return (u_pred, f_pred)


# the physics-guided neural network
class PhysicsInformedNN(flow.nn.Module):
    def __init__(self, layers):
        """layers: widths of the underlying DNN, e.g. [2, 50, ..., 1]."""
        super(PhysicsInformedNN, self).__init__()
        self.net_u_f = Net_U_F(layers)

    def forward(self, X_bc, X_domain):
        """Compute the composite PINN training loss.

        X_bc:     (n_bc, 2) array of boundary points (Dirichlet u = 0).
        X_domain: (n_dm, 2) array of interior collocation points.
        Returns a scalar tensor: mean(u_bc^2) + mean(residual^2) over the
        concatenation of boundary and interior residuals.
        """
        # One call switches the whole submodule to training mode; the
        # original called it again before the second pass redundantly.
        self.net_u_f.train()

        # boundary points: split the columns into x and y coordinate tensors
        # that require grad so the PDE residual can be differentiated
        x_bc = flow.tensor(X_bc[:, 0:1], requires_grad=True).float().to(device)
        y_bc = flow.tensor(X_bc[:, 1:2], requires_grad=True).float().to(device)
        u_bc_pred, f_bc_pred = self.net_u_f(x_bc, y_bc)

        # interior collocation points: only the residual is used here
        x_dm = flow.tensor(X_domain[:, 0:1], requires_grad=True).float().to(device)
        y_dm = flow.tensor(X_domain[:, 1:2], requires_grad=True).float().to(device)
        _, f_dm_pred = self.net_u_f(x_dm, y_dm)

        # Dirichlet term (u = 0 on the boundary) + PDE residual term
        loss = flow.mean((u_bc_pred)**2) + flow.mean(flow.cat([f_bc_pred, f_dm_pred], dim=0)**2)
        return loss

    # NOTE(review): this shadows flow.nn.Module.eval(); callers in this file
    # always pass X, but a bare model.eval() (no argument) would now fail.
    def eval(self, X):
        """Predict u and the residual f at the (n, 2) points X.

        Returns a pair of (n, 1) numpy arrays (u, f).
        """
        self.net_u_f.eval()
        x = flow.tensor(X[:, 0:1], requires_grad=True).float().to(device)
        y = flow.tensor(X[:, 1:2], requires_grad=True).float().to(device)
        u, f = self.net_u_f(x, y)
        u = u.detach().cpu().numpy()
        f = f.detach().cpu().numpy()
        return u, f
     

# network architecture: 2-d input (x, y), four hidden layers of 50, 1-d output
layers = [2, 50, 50, 50, 50, 1]


# build the PINN model and its optimizer
model = PhysicsInformedNN(layers)
optimizer = flow.optim.Adam(model.parameters(), lr=1e-4)

# training loop; upper bound 2001 so the i % 1000 checkpoint at step 2000 fires
# (range(1, 2000) stopped at 1999 and never saved the final model)
for i in range(1, 2001):
    # BUG FIX: forward() is declared as forward(X_bc, X_domain); the original
    # call passed (x_train_domain, x_train_bc), which enforced the boundary
    # condition on the interior points and vice versa.
    loss = model(x_train_bc, x_train_domain)

    # log the loss every step (original guarded this with a no-op `i % 1 == 0`)
    print(
        '========step: %d, Loss: %e' %
        (
            i,
            loss
        ), flush=True
    )
    if i % 1000 == 0:
        # report the PDE residual norm on held-out points, then checkpoint
        u_pred, f_pred = model.eval(x_test)
        error_u = np.linalg.norm(f_pred, 2)
        print('Residual u: %e' % (error_u))
        print('save step', i)
        flow.save(model.state_dict(), "./model" + str(i))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    

''' plot figure of equation'''
# 2-D view of the L-shaped domain the PDE is defined on.
fig, ax1 = plt.subplots()
points = np.array([[0, 0], [1, 0], [1, -1], [-1, -1], [-1, 1], [0, 1]])
ax1.fill(points[:, 0], points[:, 1], 'c', alpha=0.5)
ax1.scatter(points[:, 0], points[:, 1], s=50)
ax1.set_title("The domain of Poisson function defined")
ax1.set_xlabel("X")
ax1.set_ylabel("Y")

# 3-D surface of the predicted u(x, y) on the test points.
f = plt.figure()
# FIX: plt.gca(projection='3d') was deprecated in matplotlib 3.4 and removed
# in 3.6; create the 3-D axes on the figure explicitly instead.
ax2 = f.add_subplot(projection='3d')


predict, _ = model.eval(x_test)
xdata = x_test[:, 0:1].reshape(1500)
ydata = x_test[:, 1:2].reshape(1500)
predict = predict.reshape(1500)
ax2.plot_trisurf(xdata, ydata, predict, linewidth=0.3, antialiased=True, cmap='rainbow', alpha=0.8)
ax2.set_xlabel('x')
ax2.set_ylabel('y')
ax2.set_zlabel('u(x,y)')


plt.show()